{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|default_exp models.TabModel"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TabModel"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is an implementation created by Ignacio Oguiza (oguiza@timeseriesAI.co) based on fastai's TabularModel. \n",
    "\n",
    "We built it so that it's easy to change the head of the model, something that is particularly interesting when building hybrid models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "from tsai.imports import *\n",
    "# from tsai.data.tabular import *\n",
    "from tsai.models.layers import *\n",
    "from tsai.models.utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TabModel(Sequential): # Sequential accepts multiple inputs\n",
    "    \"\"\"Basic model for tabular data: a `TabBackbone` followed by a `TabHead`.\n",
    "\n",
    "    The two parts are registered under the names `backbone` and `head`.\n",
    "    `emb_szs` is an iterable of `(ni, nf)` pairs (one `Embedding(ni, nf)` per categorical column),\n",
    "    `n_cont` is the number of continuous features and `c_out` the output size.\n",
    "    `embed_p` and `bn_cont` configure the backbone; all remaining arguments are handed to `TabHead`.\n",
    "    \"\"\"\n",
    "    def __init__(self, emb_szs, n_cont, c_out, layers=None, fc_dropout=None, embed_p=0., y_range=None, use_bn=True, bn_final=False, bn_cont=True,\n",
    "                 lin_first=False, act=nn.ReLU(inplace=True), skip=False):\n",
    "\n",
    "        # Options that only concern the head\n",
    "        head_kwargs = dict(layers=layers, fc_dropout=fc_dropout, y_range=y_range, use_bn=use_bn, bn_final=bn_final,\n",
    "                           lin_first=lin_first, act=act, skip=skip)\n",
    "\n",
    "        # The backbone is created before the head, so parameters are initialised in that order\n",
    "        modules = OrderedDict()\n",
    "        modules['backbone'] = TabBackbone(emb_szs, n_cont, embed_p=embed_p, bn_cont=bn_cont)\n",
    "        modules['head'] = TabHead(emb_szs, n_cont, c_out, **head_kwargs)\n",
    "\n",
    "        super().__init__(modules)\n",
    "\n",
    "\n",
    "class TabBackbone(Module):\n",
    "    \"\"\"Embeds the categorical columns and concatenates them with the continuous features.\n",
    "\n",
    "    One `Embedding(ni, nf)` is created per `(ni, nf)` pair in `emb_szs`. Dropout with probability\n",
    "    `embed_p` is applied to the concatenated embeddings, and the continuous input goes through\n",
    "    `nn.BatchNorm1d(n_cont)` when `bn_cont` is true.\n",
    "    \"\"\"\n",
    "    def __init__(self, emb_szs, n_cont, embed_p=0., bn_cont=True):\n",
    "        self.embeds = nn.ModuleList([Embedding(n_cat, emb_dim) for n_cat, emb_dim in emb_szs])\n",
    "        self.emb_drop = nn.Dropout(embed_p)\n",
    "        if bn_cont: self.bn_cont = nn.BatchNorm1d(n_cont)\n",
    "        else: self.bn_cont = None\n",
    "        # Widths used by `forward` to decide which inputs are present\n",
    "        self.n_emb = sum(emb.embedding_dim for emb in self.embeds)\n",
    "        self.n_cont = n_cont\n",
    "\n",
    "    def forward(self, x_cat, x_cont=None):\n",
    "        \"Return the embedded columns of `x_cat` and/or `x_cont`, concatenated along dim 1.\"\n",
    "        has_emb, has_cont = self.n_emb != 0, self.n_cont != 0\n",
    "        if has_emb:\n",
    "            # Column i of x_cat is looked up in the i-th embedding\n",
    "            embedded = [emb(x_cat[:, idx]) for idx, emb in enumerate(self.embeds)]\n",
    "            x = self.emb_drop(torch.cat(embedded, 1))\n",
    "        if has_cont:\n",
    "            if self.bn_cont is not None: x_cont = self.bn_cont(x_cont)\n",
    "            x = torch.cat([x, x_cont], 1) if has_emb else x_cont\n",
    "        return x\n",
    "\n",
    "\n",
    "class TabHead(Module):\n",
    "    \"\"\"Basic head for tabular data: a stack of `LinBnDrop` blocks ending in a `c_out`-wide layer.\n",
    "\n",
    "    Args:\n",
    "        emb_szs: iterable of `(ni, nf)` pairs; the `nf` values are summed to get the embedding width.\n",
    "        n_cont: number of continuous features appended to the embeddings.\n",
    "        c_out: size of the output.\n",
    "        layers: hidden layer sizes (any list-like); defaults to `[200, 100]`.\n",
    "        fc_dropout: a single dropout value used for every hidden layer, or one value per hidden layer.\n",
    "            Defaults to no dropout.\n",
    "        y_range: if given, `SigmoidRange(*y_range)` is appended after the last layer.\n",
    "        use_bn: use batch norm in the hidden blocks.\n",
    "        bn_final: also use batch norm in the last block (only if `use_bn`).\n",
    "        lin_first: forwarded to `LinBnDrop`.\n",
    "        act: activation module placed in every hidden block (the same instance is reused).\n",
    "        skip: add a linear shortcut from the head input to its output.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `fc_dropout` is a sequence with fewer values than `layers`.\n",
    "    \"\"\"\n",
    "    def __init__(self, emb_szs, n_cont, c_out, layers=None, fc_dropout=None, y_range=None, use_bn=True, bn_final=False, lin_first=False, \n",
    "                 act=nn.ReLU(inplace=True), skip=False):\n",
    "\n",
    "        # Normalise to real lists: the concatenations below fail on tuples\n",
    "        layers = [200,100] if layers is None else list(layers)\n",
    "        ps = ifnone(fc_dropout, [0]*len(layers))\n",
    "        ps = list(ps) if is_listy(ps) else [ps]*len(layers)\n",
    "        # A short dropout list would make zip() below silently drop the trailing layers, output layer included\n",
    "        if len(ps) < len(layers):\n",
    "            raise ValueError(f'fc_dropout has {len(ps)} values but layers has {len(layers)}: pass a single value or one per layer')\n",
    "\n",
    "        n_emb = int(sum(emb_dim for _, emb_dim in emb_szs))\n",
    "        n_in = n_emb + n_cont\n",
    "        sizes = [n_in] + layers + [c_out]\n",
    "        actns = [act for _ in range(len(sizes)-2)] + [None]\n",
    "        # The last block has no activation, no dropout, and batch norm only when bn_final is set\n",
    "        _layers = [LinBnDrop(sizes[i], sizes[i+1], bn=use_bn and (i!=len(actns)-1 or bn_final), p=p, act=a, lin_first=lin_first)\n",
    "                   for i,(p,a) in enumerate(zip(ps+[0.],actns))]\n",
    "        if y_range is not None: _layers.append(SigmoidRange(*y_range))\n",
    "        self.head = nn.Sequential(*_layers)\n",
    "        # Input width of the last block; equals layers[-1], or the head input width when there are no hidden layers\n",
    "        self.head_nf = sizes[-2]\n",
    "        self.shortcut = nn.Linear(n_in, c_out) if skip else None\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"Apply the head to `x`, adding the linear shortcut of `x` when `skip` was requested.\"\n",
    "        out = self.head(x)\n",
    "        # NOTE(review): the shortcut is added after any SigmoidRange, so with skip=True the output can leave y_range\n",
    "        if self.shortcut is not None: out = out + self.shortcut(x)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai.tabular.core import *\n",
    "from tsai.data.tabular import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>workclass</th>\n",
       "      <th>education</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>education-num_na</th>\n",
       "      <th>age</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education-num</th>\n",
       "      <th>salary</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>20505</th>\n",
       "      <td>Private</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Sales</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>False</td>\n",
       "      <td>47.0</td>\n",
       "      <td>197836.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>&lt;50k</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28679</th>\n",
       "      <td>Private</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Craft-repair</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>False</td>\n",
       "      <td>28.0</td>\n",
       "      <td>65078.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>&gt;=50k</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11669</th>\n",
       "      <td>Private</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Adm-clerical</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>False</td>\n",
       "      <td>38.0</td>\n",
       "      <td>202683.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>&lt;50k</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29079</th>\n",
       "      <td>Self-emp-not-inc</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Prof-specialty</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>False</td>\n",
       "      <td>41.0</td>\n",
       "      <td>168098.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>&lt;50k</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7061</th>\n",
       "      <td>Private</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Adm-clerical</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>False</td>\n",
       "      <td>31.0</td>\n",
       "      <td>243442.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>&lt;50k</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Load the ADULT_SAMPLE dataset\n",
    "path = untar_data(URLs.ADULT_SAMPLE)\n",
    "df = pd.read_csv(path/'adult.csv')\n",
    "# df['salary'] = np.random.rand(len(df)) # uncomment to simulate a cont dependent variable\n",
    "\n",
    "# Column roles and preprocessing steps\n",
    "cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']\n",
    "cont_names = ['age', 'fnlwgt', 'education-num']\n",
    "y_names = ['salary']\n",
    "procs = [Categorify, FillMissing, Normalize]\n",
    "\n",
    "# A float target is treated as regression, anything else as classification\n",
    "if isinstance(df['salary'].values[0], float): y_block = RegressionBlock()\n",
    "else: y_block = CategoryBlock()\n",
    "\n",
    "splits = RandomSplitter()(range_of(df))\n",
    "pd.options.mode.chained_assignment=None\n",
    "to = TabularPandas(df, procs=procs, cat_names=cat_names, cont_names=cont_names, y_names=y_names, y_block=y_block, splits=splits,\n",
    "                   inplace=True, reduce_memory=False)\n",
    "to.show(5)\n",
    "\n",
    "# Build the dataloaders and check the shapes of the first training batch\n",
    "tab_dls = to.dataloaders(bs=16, val_bs=32)\n",
    "b = first(tab_dls.train)\n",
    "expected_shapes = (torch.Size([16, 7]), torch.Size([16, 3]), torch.Size([16, 1]))\n",
    "test_eq((b[0].shape, b[1].shape, b[2].shape), expected_shapes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the model from the dataloaders and check the output shape on one training batch\n",
    "tab_model = build_tabular_model(TabModel, dls=tab_dls)\n",
    "b = first(tab_dls.train)\n",
    "preds = tab_model.to(b[0].device)(*b[:-1])\n",
    "test_eq(preds.shape, (tab_dls.bs, tab_dls.c))\n",
    "\n",
    "# Freezing must lower the parameter count reported by count_parameters; unfreezing must restore it\n",
    "learn = Learner(tab_dls, tab_model, splitter=ts_splitter)\n",
    "n_params_initial = count_parameters(learn.model)\n",
    "learn.freeze()\n",
    "n_params_frozen = count_parameters(learn.model)\n",
    "learn.unfreeze()\n",
    "n_params_unfrozen = count_parameters(learn.model)\n",
    "assert n_params_initial == n_params_unfrozen\n",
    "assert n_params_initial > n_params_frozen > 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": "IPython.notebook.save_checkpoint();",
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/Users/nacho/notebooks/tsai/nbs/120_models.TabModel.ipynb saved at 2022-11-09 13:12:57\n",
      "Correct notebook to script conversion! 😃\n",
      "Wednesday 09/11/22 13:12:59 CET\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "                <audio  controls=\"controls\" autoplay=\"autoplay\">\n",
       "                    <source src=\"data:audio/wav;base64,UklGRvQHAABXQVZFZm10IBAAAAABAAEAECcAACBOAAACABAAZGF0YdAHAAAAAPF/iPh/gOoOon6w6ayCoR2ZeyfbjobxK+F2Hs0XjKc5i3DGvzaTlEaraE+zz5uLUl9f46fHpWJdxVSrnfmw8mYEScqUP70cb0Q8X41uysJ1si6Eh1jYzXp9IE2DzOYsftYRyoCY9dJ/8QICgIcEun8D9PmAaBPlfT7lq4MFIlh61tYPiCswIHX+yBaOqT1QbuW7qpVQSv9lu6+xnvRVSlyopAypbGBTUdSalrSTaUBFYpInwUpxOzhti5TOdndyKhCGrdwAfBUcXIJB69p+Vw1egB76+n9q/h6ADglbf4LvnIHfF/981ODThF4m8HiS0riJVjQ6c+/EOZCYQfJrGrhBmPVNMmNArLKhQlkXWYqhbaxXY8ZNHphLuBJsZUEckCTFVHMgNKGJytIDeSUmw4QN4Qx9pReTgb3vYX/TCBuApf75f+P5Y4CRDdN+B+tngk8c8nt03CKGqipgd13OhotwOC5x9MCAknFFcmlmtPmagFFFYOCo0qRzXMhVi57pryNmIEqJlRi8bm52PfuNM8k4dfQv+4cO12l6zCGdg3jl730uE/KAPvS+f0wEAoAsA89/XfXQgBESIn6S5luDtiC8eh/YmIfpLqt1OMp5jXg8/24MveqUNUnPZsqw0Z3yVDldnaUOqIZfXlKrm36zzWhjRhaT+r+ncHI5/otUzfd2uSt7hl/bqXtoHaCC6+mqfrAOeoDD+PJ/xf8RgLMHfH/b8GeBihZIfSXidoQSJWB52NM1iRkzz3MkxpKPbUCrbDu5d5fgTAxkSK3JoEhYD1p2omere2LZTuqYLbdWa49Cx5Dww7tyXDUnioXRkHhwJyKFvd/AfPoYy4Fl7j1/LQorgEr9/X89+0qAOAwAf13sJoL8Gkd8wt25hWIp3Heez/eKODfPcSPCzpFNRDVqf7UlmnNQKGHgqd+jgVvJVm2f265QZTpLS5byur1tpT6ajvrHq3Q2MXWIxtUCehoj8YMk5LB9hRQegeTypn+nBQWA0QHgf7f2q4C5EFt+5ucOg2YfHXtq2SSHpS0ydnTL4IxFO6pvNb4ulBdInWfcsfSc7VMmXpSmE6eeXmZThJxpsgRohEfOk86+AHCoOpOMFsx1dv8s6oYT2k17uR7ngpXod34IEJqAaPfnfyABCIBZBpl/NPI2gTQVjX134x2ExSPMeR7VtYjZMWJ0W8ftjkA/YW1durCWykvjZFKu4p9LVwVbZKNkqpxh6U+6mRC2mGq2Q3SRvsIgcpc2sIpD0Bp4uiiFhW3ecXxOGgaCDe0Vf4cLPoDv+/5/mfw1gN4KKX+17emBqBmYfBHfVYUZKFR44NBtiv41bHJUwx+RJkP1apu2VJlkTwli4qrwoo1ax1dToNCtemRSTBGXz7kJbdM/PY/Dxht0dTLziH7Ul3loJEiE0uJsfdsVTYGL8Yt/AgcMgHYA7X8S+IqAYA+QfjzpxIIVHnp7tdqzhmAstXaxzEqMETpScGC/dJP3Rmdo8LIZnOVSEF+Opxumsl1sVF+dVrE5Z6NIiZSkvVdv2zsqjdnK8HVDLlyHyNjuegogM4NA5z9+YRG9gA722H97AgOA/gSyf43zCIHdE899yuTIg3ciNXpm1jmImTDwdJPITI4RPhRugbvslbFKt2Vfr/6eTFb4W1WkY6m6YPdQjJr2tNZp3EQlko7BgXHRNz2LAc+gdwMq7IUf3R58ohtFgrbr6n7hDFWAlPr8f/T9I4CECU9/De+vgVQY5nxh4POEzybJeCTS5YnCNAZzhsRzkP1Bsmu4t4aYU07nYuerA6KWWcJYO6HHrKJjaE3Zl624UWz/QOOPjcWHc7QzdIk40yl5tCWjhIDhJX0xF4CBMvBsf10IF4Ac//Z/bPlsgAcO
wn6S6n6CwxzUewLcRoYaKzV38M23i9o493CNwL6S1UUuaQe0QpvbUfdfiqglpcRccFU+nkWwambASUiVfLyqbg49xY2eyWh1hy/Sh37XjHpaIYKD7OUEfrgS5IC09MV/1gMBgKMDyH/n9N6AhhINfh7mdoMoIZt6r9fAh1cvfHXNya6N4DzDbqi8K5WWSYlmbbAdnkpV6FxJpWSo1V8DUmGb3rMRaQBG2JJgwN9wCDnNi8HNI3dKK1aG0dvHe/UciIJf6rt+Og5wgDn59X9P/xWAKQhxf2XweYH+FjB9suGVhIMlOnlo02GJhTOdc7vFyo/TQGxs2Li7lz9NwmPurBihnVi7WSWiwKvGYntOpJiOt5drKUKMkFnE8HLxNPmJ9NG4eP8mAYUv4Np8hhi3gdruSX+3CSWAwP38f8f6UoCuDPF+6Os8gnAbKnxQ3d2F0imydzDPKIuiN5lxu8EKkrFE82kftW2az1DbYImpMqTUW3FWIJ83r5hl2koJlla7+m0+PmSOZcjcdMgwS4g11iZ6qCLUg5jkxn0QFA6BWvOvfzEFBIBHAtp/Qfa3gC4RSH5y5yeD2B/8evnYS4cULgR2CMsUja47cG/QvW6UeEhXZ3+xP51GVNVdP6Zpp+1eDFM5nMeySWghR4+TNL85cD46YIyCzKJ2kCzEhoTabXtGHs+CCemJfpMPjoDe9+t/qQALgM8Gj3++8UaBqRV2fQTjO4Q3JKd5r9TgiEYyMHTxxiWPpz8jbfq585YpTJpk960xoKFXsVoTo7yq6GGMTw==\" type=\"audio/wav\" />\n",
       "                    Your browser does not support the audio element.\n",
       "                </audio>\n",
       "              "
      ],
      "text/plain": [
       "<IPython.lib.display.Audio object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#|eval: false\n",
    "#|hide\n",
    "# Look up this notebook's name, then run the tsai script-creation step on it\n",
    "from tsai.export import get_nb_name\n",
    "from tsai.imports import create_scripts\n",
    "nb_name = get_nb_name(locals())\n",
    "create_scripts(nb_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
